# Causal Bayesian Network
import time
from typing import Any, Optional

import networkx as nx
import pandas as pd
from custom_models.CustomCausalModel import CustomCausalModel
from pgmpy.estimators import BayesianEstimator, MaximumLikelihoodEstimator
from pgmpy.inference import CausalInference
from pgmpy.models.BayesianModel import BayesianNetwork

from utils.graph import get_dag_from_causal_graph


class CBN(CustomCausalModel):
    """Causal Bayesian Network backed by a pgmpy ``BayesianNetwork``.

    The network structure is a single DAG extracted from the supplied causal
    graph via ``get_dag_from_causal_graph``. CPDs are learned in :meth:`fit`,
    and interventional queries are answered with pgmpy's ``CausalInference``
    in :meth:`estimate_effect`.
    """

    # Number of samples drawn from the (conditional) interventional
    # distribution when ``method_params`` does not provide "n_samples".
    DEFAULT_N_SAMPLES = 10000

    def __init__(self, causal_graph: nx.DiGraph):
        """Build the (unfitted) network from ``causal_graph``.

        Args:
            causal_graph: Causal graph; it may describe an equivalence class
                (PDAG), from which one DAG is selected.
        """
        # We need to get one DAG from the equivalence class described by the PDAG.
        causal_graph = get_dag_from_causal_graph(causal_graph)
        self.model = BayesianNetwork(causal_graph)

    def identify_effect(self, treatment: str, outcome: str):
        """Delegate effect identification to ``self.inference``.

        NOTE(review): ``self.inference`` is not assigned anywhere in this
        class, so unless ``CustomCausalModel`` provides it this raises
        AttributeError. Also confirm that the object stored there exposes an
        ``identify_effect`` method — TODO.
        """
        return self.inference.identify_effect(treatment, outcome)

    def get_states(self, data: pd.DataFrame):
        """Return ``{column: [unique values in that column]}`` for every column of ``data``."""
        return {column: list(data[column].unique()) for column in data.columns}

    def fit(
        self,
        data: Optional[pd.DataFrame] = None,
        int_table: Optional[pd.DataFrame] = None,
        method_params: Optional[dict[str, Any]] = None,
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
        outcome: Optional[str] = None,
        treatment: Optional[dict[str, float]] = None,
        evidence: Optional[dict[str, float]] = None,
    ) -> dict[str, Any]:
        """Learn all CPDs of the network from ``data``.

        Args:
            data: Training data. Columns not present in the graph are dropped
                before fitting.
            method_params: Must contain "estimator", either
                "BayesianEstimator" or "MaximumLikelihoodEstimator". For the
                Bayesian estimator, an optional "prior_type" (default "K2")
                is forwarded to pgmpy.
            int_table, seed, save_dir, outcome, treatment, evidence: Accepted
                for interface compatibility; not used by this model.

        Returns:
            ``{"Training Time": <CPU seconds spent in model.fit>}``.

        Raises:
            ValueError: If ``data`` is None or the estimator name is missing
                or unknown.
        """
        if data is None:
            raise ValueError("CBN.fit requires a training DataFrame (`data`).")
        if method_params is None:
            method_params = {}

        estimator_name = method_params.get("estimator")
        if estimator_name == "BayesianEstimator":
            estimator = BayesianEstimator
            # Only the Bayesian estimator accepts a prior; "BDeu" is an alternative to "K2".
            estimator_kwargs = {"prior_type": method_params.get("prior_type", "K2")}
        elif estimator_name == "MaximumLikelihoodEstimator":
            estimator = MaximumLikelihoodEstimator
            # MLE takes no prior; forwarding prior_type would raise TypeError.
            estimator_kwargs = {}
        else:
            raise ValueError(
                "method_params['estimator'] must be 'BayesianEstimator' or "
                f"'MaximumLikelihoodEstimator', got {estimator_name!r}"
            )

        # State names are collected before the dtype cast below.
        states = self.get_states(data)

        # Remove from the dataframe those columns that are not present in the graph.
        # We need Float32 because 64 bits require way too much memory.
        data = data[list(self.model.nodes())].astype('Float32')

        print("Fitting model")
        start_train_time = time.process_time()
        self.model.fit(data, estimator=estimator, state_names=states, **estimator_kwargs)
        delta_train_time = time.process_time() - start_train_time

        print(f"Causal Bayesian Network fitted in {delta_train_time} CPU seconds")

        return {"Training Time": delta_train_time}

    def _interventional_effect(self, infer, outcome, treatment_dict, control_dict, ev):
        """Return ``(treated_query, E[outcome | do(treat), ev] - E[outcome | do(control), ev])``."""
        treated_query = infer.query(variables=[outcome], do=treatment_dict, evidence=ev)
        untreated_query = infer.query(variables=[outcome], do=control_dict, evidence=ev)

        treated_avg = expected_value(
            treated_query.values, treated_query.state_names[outcome]
        )
        untreated_avg = expected_value(
            untreated_query.values, untreated_query.state_names[outcome]
        )
        return treated_query, treated_avg - untreated_avg

    def estimate_effect(
        self,
        outcome: str,
        treatment: Optional[dict[str, float]] = None,
        evidence: Optional[dict[str, float]] = None,
        method_params: Optional[dict[str, Any]] = None,
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
        data: Optional[pd.DataFrame] = None,
        int_table: Optional[pd.DataFrame] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Estimate the ATE (and the CATE when ``evidence`` is given) via do-calculus.

        Args:
            outcome: Outcome variable name. Its state names must be
                convertible with ``float()`` (see ``expected_value``).
            treatment: Treatment specification, split into treated/control
                interventions by ``self.extract_treat_control`` (presumably
                inherited from ``CustomCausalModel`` — not defined here).
            evidence: Optional conditioning evidence for the CATE.
            method_params: Optional; "n_samples" overrides the number of
                interventional samples drawn (default ``DEFAULT_N_SAMPLES``).
            seed, save_dir, data, int_table: Accepted for interface
                compatibility; not used by this model.

        Returns:
            ``(results, runtime)``. Conditional entries ("Conditional
            Interventional Distribution", "CATE", "Conditional Interventional
            Samples", "Estimation Time CATE") are None when no evidence is
            given.
        """
        if treatment is None:
            treatment = {}
        if evidence is None:
            evidence = {}
        if method_params is None:
            method_params = {}
        n_samples = method_params.get("n_samples", self.DEFAULT_N_SAMPLES)

        # Convert the treatment dictionary to the format expected by the do-calculus.
        treatment_dict, control_dict = self.extract_treat_control(treatment)

        # --- Unconditional effect (ATE) ---
        no_evidence: dict[str, float] = {}
        print(
            f"Estimating effect of {treatment} on {outcome} conditioned on {no_evidence}"
        )
        start_estimate_time = time.process_time()
        # Built once and reused for the conditional query below.
        infer = CausalInference(self.model)
        int_query, ate = self._interventional_effect(
            infer, outcome, treatment_dict, control_dict, no_evidence
        )
        delta_estimate_time_ate = time.process_time() - start_estimate_time
        # Sample from the interventional distribution (needed for MMD).
        int_samples = self.model.simulate(n_samples=n_samples, do=treatment_dict)

        # --- Conditional effect (CATE), only when evidence is provided ---
        conditional_int_query = None
        cate = None
        delta_estimate_time_cate = None
        conditional_int_samples = None
        if evidence:
            print(
                f"Estimating effect of {treatment} on {outcome} conditioned on {evidence}"
            )
            start_estimate_time = time.process_time()
            conditional_int_query, cate = self._interventional_effect(
                infer, outcome, treatment_dict, control_dict, evidence
            )
            delta_estimate_time_cate = time.process_time() - start_estimate_time
            # Sample from the conditional interventional distribution (needed for MMD).
            conditional_int_samples = self.model.simulate(
                n_samples=n_samples, do=treatment_dict, evidence=evidence
            )

        results = {
            "target": outcome,
            "state_names": int_query.state_names,
            "Interventional Distribution": int_query.values,
            "Conditional Interventional Distribution": (
                conditional_int_query.values if conditional_int_query is not None else None
            ),
            "ATE": ate,
            "evidence": evidence if evidence else None,
            "CATE": cate,
            "Interventional Samples": int_samples,
            "Conditional Interventional Samples": conditional_int_samples,
        }

        runtime = {
            "Estimation Time ATE": delta_estimate_time_ate,
            "Estimation Time CATE": delta_estimate_time_cate,
        }

        return results, runtime




def expected_value(state_probs, values):
    """Return the probability-weighted sum ``sum_i state_probs[i] * float(values[i])``.

    Args:
        state_probs: Sequence of probabilities, one per state.
        values: State values aligned with ``state_probs``; each must be
            convertible with ``float()``.

    Returns:
        The expected value (``0`` for empty inputs).

    Raises:
        ValueError: If ``state_probs`` and ``values`` have different lengths.
    """
    if len(state_probs) != len(values):
        raise ValueError("The number of state probabilities and values must be equal.")

    return sum(prob * float(value) for prob, value in zip(state_probs, values))
